"""
    Run validation: evaluate a trained model on the validation split of a dataset.
"""
import os, csv, json
import random
import argparse
import numpy as np
from tqdm import tqdm

# torch modules
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

# custom libs
from utils.datasets import load_dataset
from utils.networks import load_network, load_trained_network
from utils.optims import define_loss_function, define_optimizer


# ------------------------------------------------------------------------------
#   Run validation
# ------------------------------------------------------------------------------
def run_validation(args):
    """Evaluate a pre-trained network on the validation loader of a dataset.

    Sets the random seeds / CUDA options, loads the dataset and the network,
    restores the trained parameters from ``args.trained`` and runs ``valid``.

    Args:
        args: parsed command-line namespace. Reads ``trained``, ``cuda``,
            ``seed``, ``num_workers``, ``pin_memory``, ``dataset``,
            ``datapth``, ``batch_size``, ``network`` and ``lossfunc``.
            ``args.cuda`` is overwritten with False (in place) when CUDA
            is unavailable.

    Returns:
        tuple: ``(valid_acc, valid_loss)`` as returned by ``valid``.

    Raises:
        ValueError: if no trained-model filepath (``args.trained``) is given.
    """
    # fail fast, before the (possibly slow) dataset load; an explicit raise
    # is used instead of `assert` so the check survives `python -O`
    if not args.trained:
        raise ValueError("Error: provide the model filepath, abort.")

    # fall back to the CPU if cuda is unavailable
    if not torch.cuda.is_available():
        args.cuda = False

    # init. the random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        # deterministic CUDNN kernels; also keep benchmark mode off so that
        # kernel auto-tuning cannot pick different algorithms across runs
        cudnn.deterministic = True
        cudnn.benchmark = False

    # init. dataset (train/test); loader options only matter with cuda
    kwargs = {
        'num_workers': args.num_workers,
        'pin_memory' : args.pin_memory,
    } if args.cuda else {}
    train_loader, valid_loader = load_dataset(
        args.dataset, args.datapth, args.batch_size, kwargs)
    print(' : Load the dataset [{}] from [{}]'.format(args.dataset, args.datapth))

    # init. the network and restore the trained parameters
    network = load_network(args.dataset, args.network)
    load_trained_network(network, args.cuda, args.trained)
    if args.cuda:
        network.cuda()
    print(' : Define a network [{}]'.format(type(network).__name__))

    # init. loss function
    task_loss = define_loss_function(args.lossfunc)
    print(' : Define a loss function [{}]'.format(args.lossfunc))

    # run validation...
    vacc, vloss = valid(args, "N/A", network, valid_loader, task_loss)

    print(': Done, validation')
    # hand the results back so programmatic callers can use them
    return vacc, vloss


# ------------------------------------------------------------------------------
#   Valid functions
# ------------------------------------------------------------------------------
def valid(args, epoch, net, valid_loader, taskloss):
    """Evaluate ``net`` on ``valid_loader`` and print accuracy / average loss.

    Args:
        args: namespace; only ``args.cuda`` is read (moves each batch to GPU).
        epoch: label shown in the progress bar and in the printed report.
        net: the model; it is switched to eval mode by this function.
        valid_loader: iterable of ``(data, labels)`` batches that also exposes
            a ``.dataset`` supporting ``len()``.
        taskloss: callable invoked as ``taskloss(output, labels,
            reduction='sum')`` and returning a one-element tensor.

    Returns:
        tuple: ``(valid_acc, valid_loss)`` - accuracy in percent and the loss
        averaged over all samples. Both are 0.0 for an empty dataset.
    """
    # test
    net.eval()

    num_samples = len(valid_loader.dataset)
    valid_corr = 0
    valid_loss = 0.

    # no autograd graph is needed anywhere in evaluation; entering no_grad
    # once around the loop replaces the deprecated Variable(...) wrappers
    with torch.no_grad():
        for data, labels in tqdm(valid_loader, desc='[{}]'.format(epoch)):
            if args.cuda:
                data, labels = data.cuda(), labels.cuda()
            output = net(data)

            # predicted class = arg-max of the output along dim 1
            predict = output.max(1, keepdim=True)[1]
            valid_corr += predict.eq(labels.view_as(predict)).sum().item()
            valid_loss += taskloss(output, labels, reduction='sum').item()

    # the total loss and accuracy (guard against an empty dataset)
    if num_samples > 0:
        valid_loss /= num_samples
        valid_acc = 100. * valid_corr / num_samples
    else:
        valid_acc = 0.

    # report the result
    print('  Epoch: {} [{}/{} (Acc: {:.4f}%)]\tAverage loss: {:.6f}'.format(
        epoch, valid_corr, num_samples, valid_acc, valid_loss))

    # return acc and loss
    return valid_acc, valid_loss


"""
    Main (checks the accuracy of a trained model)
"""
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Run validation')

    # system parameters
    parser.add_argument('--seed', type=int, default=215,
                        help='random seed (default: 215)')
    parser.add_argument('--cuda', action='store_true',
                        help='enables CUDA (ignored if CUDA is unavailable)')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='number of data-loader workers (default: 4)')
    # `store_false`: pinned memory is ON by default; passing the flag turns it OFF
    parser.add_argument('--pin-memory', action='store_false',
                        help='disable copying tensors into CUDA pinned memory '
                             '(pinned memory is enabled by default)')

    # dataset parameters
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset to validate on (default: cifar10).')
    parser.add_argument('--datapth', type=str, default='',
                        help='dataset location (which uses a processed file)')

    # model parameters
    parser.add_argument('--network', type=str, default='ConvNet',
                        help='model name (default: ConvNet).')
    parser.add_argument('--trained', type=str, default='',
                        help='pre-trained model filepath (required).')
    parser.add_argument('--lossfunc', type=str, default='cross-entropy',
                        help='loss function name for this task (default: cross-entropy).')
    # NOTE(review): not read by run_validation / valid in this file
    parser.add_argument('--classes', type=int, default=10,
                        help='number of classes in the dataset (ex. 10 in CIFAR10).')
    parser.add_argument('--batch-size', type=int, default=125,
                        help='input batch size for validation (default: 125)')

    # execution parameters
    args = parser.parse_args()
    # report a missing model path as a usage error rather than a traceback
    if not args.trained:
        parser.error('--trained: provide the model filepath')
    print(json.dumps(vars(args), indent=2))

    # run the validation
    run_validation(args)
    # Fin.
